/* -----------------------------------------------------------------------------
 * Programmer(s): David J. Gardner @ LLNL
 * -----------------------------------------------------------------------------
 * SUNDIALS Copyright Start
 * Copyright (c) 2002-2025, Lawrence Livermore National Security
 * and Southern Methodist University.
 * All rights reserved.
 *
 * See the top-level LICENSE and NOTICE files for details.
 *
 * SPDX-License-Identifier: BSD-3-Clause
 * SUNDIALS Copyright End
 * -----------------------------------------------------------------------------
 * This the implementation file for the IDA nonlinear solver interface.
 * ---------------------------------------------------------------------------*/

#include "idas_impl.h"
#include "sundials/sundials_math.h"
#include "sundials/sundials_nvector_senswrapper.h"

/* constant macros */
#define ONE    SUN_RCONST(1.0)  /* real 1.0    */
#define TWENTY SUN_RCONST(20.0) /* real 20.0   */

/* nonlinear solver parameters */
#define MAXIT 4 /* default max number of nonlinear iterations    */
#define RATEMAX \
  SUN_RCONST(0.9) /* max convergence rate used in divergence check */

/* private functions passed to nonlinear solver */
static int idaNlsResidualSensStg(N_Vector ycor, N_Vector res, void* ida_mem);
static int idaNlsLSetupSensStg(sunbooleantype jbad, sunbooleantype* jcur,
                               void* ida_mem);
static int idaNlsLSolveSensStg(N_Vector delta, void* ida_mem);
static int idaNlsConvTestSensStg(SUNNonlinearSolver NLS, N_Vector ycor,
                                 N_Vector del, sunrealtype tol, N_Vector ewt,
                                 void* ida_mem);

/* -----------------------------------------------------------------------------
 * Exported functions
 * ---------------------------------------------------------------------------*/

int IDASetNonlinearSolverSensStg(void* ida_mem, SUNNonlinearSolver NLS)
{
  IDAMem IDA_mem;
  int retval, is;

  /* return immediately if IDA memory is NULL */
  if (ida_mem == NULL)
  {
    IDAProcessError(NULL, IDA_MEM_NULL, __LINE__, __func__, __FILE__, MSG_NO_MEM);
    return (IDA_MEM_NULL);
  }
  IDA_mem = (IDAMem)ida_mem;

  /* return immediately if NLS memory is NULL */
  if (NLS == NULL)
  {
    IDAProcessError(NULL, IDA_ILL_INPUT, __LINE__, __func__, __FILE__,
                    "NLS must be non-NULL");
    return (IDA_ILL_INPUT);
  }

  /* check for required nonlinear solver functions */
  if (NLS->ops->gettype == NULL || NLS->ops->solve == NULL ||
      NLS->ops->setsysfn == NULL)
  {
    IDAProcessError(IDA_mem, IDA_ILL_INPUT, __LINE__, __func__, __FILE__,
                    "NLS does not support required operations");
    return (IDA_ILL_INPUT);
  }

  /* check for allowed nonlinear solver types */
  if (SUNNonlinSolGetType(NLS) != SUNNONLINEARSOLVER_ROOTFIND)
  {
    IDAProcessError(IDA_mem, IDA_ILL_INPUT, __LINE__, __func__, __FILE__,
                    "NLS type must be SUNNONLINEARSOLVER_ROOTFIND");
    return (IDA_ILL_INPUT);
  }

  /* check that sensitivities were initialized */
  if (!(IDA_mem->ida_sensi))
  {
    IDAProcessError(IDA_mem, IDA_ILL_INPUT, __LINE__, __func__, __FILE__,
                    MSG_NO_SENSI);
    return (IDA_ILL_INPUT);
  }

  /* check that the staggered corrector was selected */
  if (IDA_mem->ida_ism != IDA_STAGGERED)
  {
    IDAProcessError(IDA_mem, IDA_ILL_INPUT, __LINE__, __func__, __FILE__,
                    "Sensitivity solution method is not IDA_STAGGERED");
    return (IDA_ILL_INPUT);
  }

  /* free any existing nonlinear solver */
  if ((IDA_mem->NLSstg != NULL) && (IDA_mem->ownNLSstg))
  {
    retval = SUNNonlinSolFree(IDA_mem->NLSstg);
  }

  /* set SUNNonlinearSolver pointer */
  IDA_mem->NLSstg = NLS;

  /* Set NLS ownership flag. If this function was called to attach the default
     NLS, IDA will set the flag to SUNTRUE after this function returns. */
  IDA_mem->ownNLSstg = SUNFALSE;

  /* set the nonlinear residual function */
  retval = SUNNonlinSolSetSysFn(IDA_mem->NLSstg, idaNlsResidualSensStg);
  if (retval != IDA_SUCCESS)
  {
    IDAProcessError(IDA_mem, IDA_ILL_INPUT, __LINE__, __func__, __FILE__,
                    "Setting nonlinear system function failed");
    return (IDA_ILL_INPUT);
  }

  /* set convergence test function */
  retval = SUNNonlinSolSetConvTestFn(IDA_mem->NLSstg, idaNlsConvTestSensStg,
                                     ida_mem);
  if (retval != IDA_SUCCESS)
  {
    IDAProcessError(IDA_mem, IDA_ILL_INPUT, __LINE__, __func__, __FILE__,
                    "Setting convergence test function failed");
    return (IDA_ILL_INPUT);
  }

  /* set max allowed nonlinear iterations */
  retval = SUNNonlinSolSetMaxIters(IDA_mem->NLSstg, MAXIT);
  if (retval != IDA_SUCCESS)
  {
    IDAProcessError(IDA_mem, IDA_ILL_INPUT, __LINE__, __func__, __FILE__,
                    "Setting maximum number of nonlinear iterations failed");
    return (IDA_ILL_INPUT);
  }

  /* create vector wrappers if necessary */
  if (IDA_mem->stgMallocDone == SUNFALSE)
  {
    IDA_mem->ypredictStg = N_VNewEmpty_SensWrapper(IDA_mem->ida_Ns,
                                                   IDA_mem->ida_sunctx);
    if (IDA_mem->ypredictStg == NULL)
    {
      IDAProcessError(IDA_mem, IDA_MEM_FAIL, __LINE__, __func__, __FILE__,
                      MSG_MEM_FAIL);
      return (IDA_MEM_FAIL);
    }

    IDA_mem->ycorStg = N_VNewEmpty_SensWrapper(IDA_mem->ida_Ns,
                                               IDA_mem->ida_sunctx);
    if (IDA_mem->ycorStg == NULL)
    {
      N_VDestroy(IDA_mem->ypredictStg);
      IDAProcessError(IDA_mem, IDA_MEM_FAIL, __LINE__, __func__, __FILE__,
                      MSG_MEM_FAIL);
      return (IDA_MEM_FAIL);
    }

    IDA_mem->ewtStg = N_VNewEmpty_SensWrapper(IDA_mem->ida_Ns,
                                              IDA_mem->ida_sunctx);
    if (IDA_mem->ewtStg == NULL)
    {
      N_VDestroy(IDA_mem->ypredictStg);
      N_VDestroy(IDA_mem->ycorStg);
      IDAProcessError(IDA_mem, IDA_MEM_FAIL, __LINE__, __func__, __FILE__,
                      MSG_MEM_FAIL);
      return (IDA_MEM_FAIL);
    }

    IDA_mem->stgMallocDone = SUNTRUE;
  }

  /* attach vectors to vector wrappers */
  for (is = 0; is < IDA_mem->ida_Ns; is++)
  {
    NV_VEC_SW(IDA_mem->ypredictStg, is) = IDA_mem->ida_yySpredict[is];
    NV_VEC_SW(IDA_mem->ycorStg, is)     = IDA_mem->ida_eeS[is];
    NV_VEC_SW(IDA_mem->ewtStg, is)      = IDA_mem->ida_ewtS[is];
  }

  return (IDA_SUCCESS);
}

/* -----------------------------------------------------------------------------
 * Private functions
 * ---------------------------------------------------------------------------*/

int idaNlsInitSensStg(IDAMem IDA_mem)
{
  int retval;

  /* set the linear solver setup wrapper function */
  if (IDA_mem->ida_lsetup)
  {
    retval = SUNNonlinSolSetLSetupFn(IDA_mem->NLSstg, idaNlsLSetupSensStg);
  }
  else { retval = SUNNonlinSolSetLSetupFn(IDA_mem->NLSstg, NULL); }

  if (retval != IDA_SUCCESS)
  {
    IDAProcessError(IDA_mem, IDA_ILL_INPUT, __LINE__, __func__, __FILE__,
                    "Setting the linear solver setup function failed");
    return (IDA_NLS_INIT_FAIL);
  }

  /* set the linear solver solve wrapper function */
  if (IDA_mem->ida_lsolve)
  {
    retval = SUNNonlinSolSetLSolveFn(IDA_mem->NLSstg, idaNlsLSolveSensStg);
  }
  else { retval = SUNNonlinSolSetLSolveFn(IDA_mem->NLSstg, NULL); }

  if (retval != IDA_SUCCESS)
  {
    IDAProcessError(IDA_mem, IDA_ILL_INPUT, __LINE__, __func__, __FILE__,
                    "Setting linear solver solve function failed");
    return (IDA_NLS_INIT_FAIL);
  }

  /* initialize nonlinear solver */
  retval = SUNNonlinSolInitialize(IDA_mem->NLSstg);

  if (retval != IDA_SUCCESS)
  {
    IDAProcessError(IDA_mem, IDA_ILL_INPUT, __LINE__, __func__, __FILE__,
                    MSG_NLS_INIT_FAIL);
    return (IDA_NLS_INIT_FAIL);
  }

  return (IDA_SUCCESS);
}

static int idaNlsLSetupSensStg(SUNDIALS_MAYBE_UNUSED sunbooleantype jbad,
                               sunbooleantype* jcur, void* ida_mem)
{
  IDAMem IDA_mem;
  int retval;

  if (ida_mem == NULL)
  {
    IDAProcessError(NULL, IDA_MEM_NULL, __LINE__, __func__, __FILE__, MSG_NO_MEM);
    return (IDA_MEM_NULL);
  }
  IDA_mem = (IDAMem)ida_mem;

  IDA_mem->ida_nsetupsS++;

  retval = IDA_mem->ida_lsetup(IDA_mem, IDA_mem->ida_yy, IDA_mem->ida_yp,
                               IDA_mem->ida_delta, IDA_mem->ida_tmpS1,
                               IDA_mem->ida_tmpS2, IDA_mem->ida_tmpS3);

  /* update Jacobian status */
  *jcur = SUNTRUE;

  /* update convergence test constants */
  IDA_mem->ida_cjold   = IDA_mem->ida_cj;
  IDA_mem->ida_cjratio = ONE;
  IDA_mem->ida_ss      = TWENTY;
  IDA_mem->ida_ssS     = TWENTY;

  if (retval < 0) { return (IDA_LSETUP_FAIL); }
  if (retval > 0) { return (IDA_LSETUP_RECVR); }

  return (IDA_SUCCESS);
}

static int idaNlsLSolveSensStg(N_Vector deltaStg, void* ida_mem)
{
  IDAMem IDA_mem;
  int retval, is;

  if (ida_mem == NULL)
  {
    IDAProcessError(NULL, IDA_MEM_NULL, __LINE__, __func__, __FILE__, MSG_NO_MEM);
    return (IDA_MEM_NULL);
  }
  IDA_mem = (IDAMem)ida_mem;

  for (is = 0; is < IDA_mem->ida_Ns; is++)
  {
    retval = IDA_mem->ida_lsolve(IDA_mem, NV_VEC_SW(deltaStg, is),
                                 IDA_mem->ida_ewtS[is], IDA_mem->ida_yy,
                                 IDA_mem->ida_yp, IDA_mem->ida_delta);

    if (retval < 0) { return (IDA_LSOLVE_FAIL); }
    if (retval > 0) { return (IDA_LSOLVE_RECVR); }
  }

  return (IDA_SUCCESS);
}

static int idaNlsResidualSensStg(N_Vector ycorStg, N_Vector resStg, void* ida_mem)
{
  IDAMem IDA_mem;
  int retval;

  if (ida_mem == NULL)
  {
    IDAProcessError(NULL, IDA_MEM_NULL, __LINE__, __func__, __FILE__, MSG_NO_MEM);
    return (IDA_MEM_NULL);
  }
  IDA_mem = (IDAMem)ida_mem;

  /* update yS and ypS based on the current correction */
  N_VLinearSumVectorArray(IDA_mem->ida_Ns, ONE, IDA_mem->ida_yySpredict, ONE,
                          NV_VECS_SW(ycorStg), IDA_mem->ida_yyS);
  N_VLinearSumVectorArray(IDA_mem->ida_Ns, ONE, IDA_mem->ida_ypSpredict,
                          IDA_mem->ida_cj, NV_VECS_SW(ycorStg), IDA_mem->ida_ypS);

  /* evaluate sens residual */
  retval = IDA_mem->ida_resS(IDA_mem->ida_Ns, IDA_mem->ida_tn, IDA_mem->ida_yy,
                             IDA_mem->ida_yp, IDA_mem->ida_delta,
                             IDA_mem->ida_yyS, IDA_mem->ida_ypS,
                             NV_VECS_SW(resStg), IDA_mem->ida_user_dataS,
                             IDA_mem->ida_tmpS1, IDA_mem->ida_tmpS2,
                             IDA_mem->ida_tmpS3);

  /* increment the number of sens residual evaluations */
  IDA_mem->ida_nrSe++;

  if (retval < 0) { return (IDA_SRES_FAIL); }
  if (retval > 0) { return (IDA_SRES_RECVR); }

  return (IDA_SUCCESS);
}

static int idaNlsConvTestSensStg(SUNNonlinearSolver NLS,
                                 SUNDIALS_MAYBE_UNUSED N_Vector ycor,
                                 N_Vector del, sunrealtype tol, N_Vector ewt,
                                 void* ida_mem)
{
  IDAMem IDA_mem;
  int m, retval;
  sunrealtype delnrm;
  sunrealtype rate;

  if (ida_mem == NULL)
  {
    IDAProcessError(NULL, IDA_MEM_NULL, __LINE__, __func__, __FILE__, MSG_NO_MEM);
    return (IDA_MEM_NULL);
  }
  IDA_mem = (IDAMem)ida_mem;

  /* compute the norm of the correction */
  delnrm = N_VWrmsNorm(del, ewt);

  /* get the current nonlinear solver iteration count */
  retval = SUNNonlinSolGetCurIter(NLS, &m);
  if (retval != IDA_SUCCESS) { return (IDA_MEM_NULL); }

  /* test for convergence, first directly, then with rate estimate. */
  if (m == 0)
  {
    IDA_mem->ida_oldnrm = delnrm;
    if (delnrm <= IDA_mem->ida_toldel) { return (SUN_SUCCESS); }
  }
  else
  {
    rate = SUNRpowerR(delnrm / IDA_mem->ida_oldnrm, ONE / m);
    if (rate > RATEMAX) { return (SUN_NLS_CONV_RECVR); }
    IDA_mem->ida_ssS = rate / (ONE - rate);
  }

  if (IDA_mem->ida_ssS * delnrm <= tol) { return (SUN_SUCCESS); }

  /* not yet converged */
  return (SUN_NLS_CONTINUE);
}
